import torch as pt
from python_ai.common.xcommon import *

# Demo script: shows how torch.gather selects elements along each dimension
# of a small 2-D tensor.
# NOTE(review): `sep` comes from the xcommon star import (not visible here);
# presumably it prints a labelled section separator -- confirm.

# Raise the per-edge item count so printed tensors are not summarised.
pt.set_printoptions(edgeitems=100)

# Source tensor: the integers 1..12 laid out as 3 rows by 4 columns.
sep('3x4')
x = pt.arange(1, 3*4+1).view(3, 4)
print(x)

# Gather along dim=1: out[i][j] = x[i][index[i][j]], i.e. every index row
# picks columns out of the matching row of x.
# Index tensors are built directly as int64 with pt.tensor(); the legacy
# pt.Tensor(...) constructor would create a float32 tensor first and need a
# separate .long() cast.
sep('gather dim=1 (meaningful)')
col_index = pt.tensor([[2, 1, 0], [2, 1, 0], [3, 2, 1]], dtype=pt.long)
g = pt.gather(x, 1, col_index)
print(g)

# Gather along dim=0: out[i][j] = x[index[i][j]][j], i.e. every index entry
# picks a row for its own column j.
sep('gather dim=0')
row_index = pt.tensor([[2, 1, 0], [2, 1, 0], [2, 1, 0]], dtype=pt.long)
g = pt.gather(x, 0, row_index)
print(g)
